require(colorout)
require(sqldf)
require(arm)
require(rstan)
require(foreach)
require(doMC)
require(RPostgreSQL)

# Start from an empty workspace, then set session-wide options.
# NOTE(review): rm(list=ls()) deletes every object in the environment this
# script is evaluated in -- when sourced interactively that is the caller's
# global workspace. Kept as-is to preserve behaviour; consider dropping it.
rm(list=ls())

# digits/scipen/width control how numbers are printed and how wide the console
# output is; java.parameters is only stored here (nothing in this part of the
# file reads it -- presumably for an rJava-based package, TODO confirm).
options(
  digits=2,
  scipen=9,
  width=110,
  java.parameters="-Xrs")

################################################################################################################
# helper functions

# Log-odds of a probability.
#   x: numeric vector of probabilities.
# Values are first clamped to [0.005, 0.995], so 0 and 1 map to finite
# log-odds instead of -Inf / Inf. Returns a numeric vector the length of x.
logit <- function(x) {
  clamped <- pmin(0.995, pmax(0.005, x))
  -log(1 / clamped - 1)
}

# Objective minimised by FindDelta(): how far the n-weighted mean of the
# shifted probabilities is from the target.
#   delta: shift added on the logit scale.
#   y:     numeric vector of probabilities (clamped to [0.005, 0.995]).
#   n:     weights (counts) matching y.
#   yhat:  target value for the weighted mean.
# Returns abs(weighted mean of invlogit(logit(y) + delta) - yhat).
# invlogit() is not defined in this file (presumably arm::invlogit, loaded at
# the top -- TODO confirm).
AbsError <- function(delta, y, n, yhat) {
  bounded <- pmin(0.995, pmax(0.005, y))
  shifted <- invlogit(logit(bounded) + delta)
  weighted_mean <- sum(shifted * n) / sum(n)
  abs(weighted_mean - yhat)
}

# Find the logit-scale shift that brings the n-weighted mean of y closest to
# yhat.
#   y, n, yhat: forwarded unchanged to AbsError().
# optimize() searches delta over [-5, 5]; the value returned is the delta at
# which AbsError() is smallest (a single number).
FindDelta <- function(y, n, yhat) {
  fit <- optimize(AbsError, interval=c(-5, 5), y=y, n=n, yhat=yhat)
  fit$minimum
}

# Open the shared database connection used by SQL() and dbWriteTable() below.
# The OS user name (the USER environment variable) doubles as both the
# PostgreSQL role and the database name; the server is expected on
# localhost:5432 with an empty password.
# Sys.getenv() reads the variable directly instead of spawning a shell to echo
# it; like the shell version it yields "" when USER is unset.
username <- Sys.getenv("USER")
conn <- dbConnect(
  dbDriver("PostgreSQL"),
  user=username,
  password="",
  host="localhost",
  port=5432,
  dbname=username)
# Run a query on the shared connection `conn` and return whatever
# dbGetQuery() returns for it.
#   x: SQL statement as a string.
SQL <- function(x) {
  dbGetQuery(conn, statement=x)
}
# Run SQL through the psql command-line client (a separate process, not the
# `conn` connection), against the database and role named by $USER.
#   x: SQL text as a single string; may contain several statements.
# Returns the exit status of the shell command, invisibly, and warns when it
# is non-zero so a failed statement does not go unnoticed.
SQL_EXECUTE <- function(x) {
  # shQuote() wraps the SQL in single quotes for the shell and escapes any
  # single quote inside it; for quote-free SQL the command line is exactly
  # what it was when the quotes were pasted on by hand.
  command <- paste0("psql -d $USER -U $USER -c ", shQuote(x, type="sh"), ";")
  status <- system(command)
  if (!identical(as.integer(status), 0L))
    warning("psql exited with status ", status, call.=FALSE)
  invisible(status)
}

################################################################################################################

# number of yhat<i> columns in v1 to process (yhat1 .. yhat<n_samples>)
n_samples <- 100

# load reference data; y is the share d2012 / (d2012 + r2012) per county
ref <- data.table::fread("data/yg_pamrp_fips2012results.txt", data.table=FALSE)
ref$fips <- stringr::str_pad(ref$fips, width=5, pad="0")
ref$y <- ref$d2012 / (ref$d2012 + ref$r2012)

# make sure we have the same set of counties across datasets
# `fips` keeps the county codes exactly as v1 stores them (they are uploaded
# again below and joined back to v1 on that column); `fips_key` is the
# zero-padded form, which is what ref$fips and the per-sample query results
# use. Matching on the padded key keeps counties whose code starts with 0
# even when v1 stores the code without the leading zero.
fips <- SQL("select distinct(fips) from v1 where voted2012 = 1")[,1]
fips <- fips[!is.na(fips)]
fips_key <- stringr::str_pad(fips, width=5, pad="0")
in_ref <- fips_key %in% ref$fips
key_order <- order(fips_key[in_ref])
fips <- fips[in_ref][key_order]
fips_key <- fips_key[in_ref][key_order]
ref <- ref[ref$fips %in% fips_key,]
ref <- ref[order(ref$fips),]
target <- ref$y
names(target) <- ref$fips

# find deltas
# One numeric vector per sample, each ordered like `fips`.
results <- vector("list", n_samples)
for (i_sample in seq_len(n_samples)) {
  message("Finding county deltas, ", i_sample, " of ", n_samples)
  # Per county: the distinct predicted probabilities (rounded to 2 decimals)
  # among rows with voted2012 = 1, and how many rows have each value.
  dat <- SQL(paste0("
    select 
    fips, 
    round(1 / (1 + exp(-1 * (yhat", i_sample, "))), 2) as y, count(*) as n
    from v1
    where voted2012 = 1
    group by 1, 2 
    order by 1, 2
  "))
  dat$fips <- stringr::str_pad(dat$fips, width=5, pad="0")
  dat <- dat[dat$fips %in% fips_key,]

  y <- split(dat$y, dat$fips)
  n <- split(dat$n, dat$fips)
  county_delta <- vapply(names(y), function(i)
    FindDelta(y=y[[i]], n=n[[i]], yhat=target[[i]]), numeric(1))
  # Align by county code rather than by position, so every sample's vector is
  # in the same order as `fips`; a county absent from this sample gets NA.
  results[[i_sample]] <- unname(county_delta[fips_key])
}

# one row per county, one column d<i> per sample
deltas <- data.frame(
  fips=fips, 
  do.call(cbind, results), 
  stringsAsFactors=FALSE)
colnames(deltas) <- c("fips", paste0("d", seq_len(n_samples)))
rownames(deltas) <- NULL

# upload to postgres
# NOTE(review): row.names=TRUE also writes the (now sequential) row names as
# an extra column; left unchanged so the table keeps its current shape.
SQL_EXECUTE("drop table if exists deltas;")
dbWriteTable(conn, "deltas", deltas, row.names=TRUE, append=FALSE)

######################################################################
# score

# Build table `final`: one row per v1 row with, for every sample i, the
# adjusted probability invlogit(yhat<i> + d<i>) rounded to 3 decimals. The
# county delta d<i> comes from `deltas` via a left join on fips and counts as
# 0 when there is no matching county (coalesce). Column `yhat` is created
# empty here and filled in by the update statement further down.
sample_ids <- 1:n_samples

# one select-list entry per sample, joined with ",\n"
adjusted_columns <- paste0("
    round(cast(1 / (1 + exp(-1 * (
      coalesce(b.d", sample_ids, ", 0) + 
      yhat", sample_ids, "
    ))) as numeric), 3) as yhat", sample_ids, "
  ", collapse=",\n")

create_final_sql <- paste0("
  drop table if exists final cascade; 
  create table final as
  select randomid, a.state, 
  cast(NULL as numeric) as yhat, ", adjusted_columns, "
  from v1 a 
  left join deltas b on (a.fips=b.fips); 
")
SQL_EXECUTE(create_final_sql)

# yhat = mean of yhat1 .. yhat<n_samples>, rounded to 3 decimals
yhat_sum <- paste0("yhat", sample_ids, collapse=" + ")
average_sql <- paste0("
  update final set yhat = round((", yhat_sum, ") / ", n_samples, ", 3);
")
SQL_EXECUTE(average_sql)
